Fix NaN value comparisons in relu, max and min ops #14262
Conversation
Why do we need this? Don't we already have an is_nan() for all operators?
@apeforest we need it because of operations like this:
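A minimal sketch of the kind of operation in question, in plain Python/NumPy rather than the actual MXNet kernels: because every ordered comparison involving NaN is false, a kernel written as a plain comparison silently swallows NaN values.

```python
import numpy as np

nan = float("nan")

def naive_max(a, b):
    # mimics a kernel written as (a > b) ? a : b
    return a if a > b else b

def naive_relu(x):
    # mimics a kernel written as (x > 0) ? x : 0
    return x if x > 0.0 else 0.0

print(naive_max(nan, 1.0))   # 1.0 -- the NaN is swallowed
print(naive_relu(nan))       # 0.0 -- the NaN is swallowed
print(np.maximum(nan, 1.0))  # nan -- NumPy propagates the NaN here
```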
@anirudhacharya I'm not sure if it answers the question. Why do we need to start supporting nan values, especially given the extra handling required for nan?
@szha Ideally I do not want PyTorch's
Also, I found a related issue here - #14157. Edit - another issue filed some time ago that had slipped my memory - #11115
@mxnet-label-bot add [pr-awaiting-review]
@anirudhacharya thanks for the explanation. Should relu grad deal with nan in a special way?
@szha Yes, I think the relu grad should also be handled in a special way, thanks for pointing it out. Currently relu grad at
But FYI - here is an in-depth conversation on
Nice PR! I also had a bug in my model, and because the relu activations removed NaNs it took me much longer to realize there was a bug. The behaviour should definitely be changed!
Force-pushed from 8e1bd80 to 43efb35.
@anirudhacharya I'm not sure if relu grad should act like that. As a sanity check, consider whether nan is larger than or smaller than 0.
Seems that nan should be surfaced in relu grad instead of 1 when the output is nan, because nan is not a number.
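In other words, a backward rule written purely as a comparison maps a NaN output to a gradient of 0 or 1 and hides it. A small sketch of the difference, where `relu_grad_nan_aware` is a hypothetical NaN-aware variant rather than MXNet's actual kernel:

```python
nan = float("nan")

def relu_grad_naive(out):
    # comparison-based rule: a NaN output silently yields gradient 0
    return 1.0 if out > 0.0 else 0.0

def relu_grad_nan_aware(out):
    # hypothetical NaN-aware rule: surface the NaN instead of hiding it
    if out != out:  # true only for NaN
        return nan
    return 1.0 if out > 0.0 else 0.0

print(relu_grad_naive(nan))      # 0.0
print(relu_grad_nan_aware(nan))  # nan
```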
Could we add a new operator to check whether there are nan values? It would be useful when debugging.
nan compared to any number is always false. Ref - https://stackoverflow.com/questions/49011370/nan-propagation-and-ieee-754-standard/49040225 But there are languages and libraries which consider nan to be greater than any number, even np.inf.
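Concretely, in any IEEE 754-conforming language (plain Python here, values chosen arbitrarily):

```python
nan = float("nan")

print(nan > 0.0)    # False
print(nan < 0.0)    # False
print(nan >= 0.0)   # False
print(nan == nan)   # False
print(nan != nan)   # True -- the only comparison that holds for NaN
```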
@anirudhacharya it's not about comparison. nan is not in the domain of the function.
@szha What you say makes sense; nan is not a number and hence not in the realm of comparison, so any occurrence in either the forward or the backward pass will have to be propagated. Maybe PyTorch is doing it wrong, but just for comparison's sake PyTorch seems to treat the gradient of relu at NaN as equal to
My main motivation when I first made changes to the relu forward behavior was that the operator silently clipping NaN values was very misleading while trying to build or debug models. I am open to suggestions on how the relu gradient should behave; there seems to be no single consensus on this, and each community/library decides it for itself.
I think you are looking for this - http://mxnet.incubator.apache.org/api/python/ndarray/contrib.html?highlight=isnan#mxnet.ndarray.contrib.isnan
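A quick usage sketch of that contrib operator (the input values here are arbitrary):

```python
import mxnet as mx

x = mx.nd.array([1.0, float("nan"), -2.0])
# Elementwise mask: 1 where the entry is NaN, 0 elsewhere.
print(mx.nd.contrib.isnan(x))  # [0. 1. 0.]
```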
Force-pushed from 95ce922 to 7dbf2bb.
I modified relu grad to also propagate nan.
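A sketch of how that can be checked from Python, assuming the post-change behaviour where NaN propagates through both the forward and backward pass:

```python
import mxnet as mx

x = mx.nd.array([-1.0, float("nan"), 2.0])
x.attach_grad()
with mx.autograd.record():
    y = mx.nd.relu(x)
y.backward()

# With NaN propagated in both passes, the NaN entry is expected to show up
# in both the output and the gradient instead of silently becoming 0.
print(y)       # [0., nan, 2.]
print(x.grad)  # [0., nan, 1.]
```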
@anirudhacharya one last thing: could you measure the performance before and after this change? The change is necessary regardless, but it would still be better if we could anticipate any performance impact from it. Thanks.
Run Mode: Before --> After (time in ms)
Whole CPU run: 0.843163 --> 0.864071

Script used:

import mxnet as mx
import numpy as np
from mxnet.test_utils import check_speed

ctx = mx.cpu()
# ctx = mx.gpu(0)

sample_data = mx.nd.ones((3, 500, 500), ctx=ctx)
sample_data[0] = -1.
sample_data[1] = np.NaN

sample = mx.sym.Variable("sample")
relu_sym = mx.sym.relu(data=sample)

print("Whole CPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="whole"))
print("Forward CPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="forward"))
# print("Whole GPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="whole"))
# print("Forward GPU run: ", check_speed(relu_sym, location={"sample": sample_data}, ctx=ctx, N=int(1e5), typ="forward"))
* nan comparison
* fix relu grad
Description
Fix NaN comparisons in relu, max and min ops
Fixes #14157
Fixes #11115
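A minimal sketch of the intended post-fix behaviour, assuming NaN now propagates through these ops instead of being dropped by the underlying comparisons:

```python
import mxnet as mx

nan = float("nan")
x = mx.nd.array([1.0, nan, 3.0])

# A NaN in the input is expected to propagate to the result instead of
# being silently ignored by the comparison-based kernels.
print(mx.nd.relu(x))  # [1., nan, 3.]
print(mx.nd.max(x))   # [nan]
print(mx.nd.min(x))   # [nan]
```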
@anirudh2290 @apeforest